# Import modules
import sys, math, json, csv
import cartopy.io.shapereader as shpreader
from shapely.geometry import mapping, MultiPolygon, Polygon
import xml.etree.ElementTree as ET
from myutilities import readDataFromFile

########################################################################################################################
def getBinFromValueAndCutoffs(value, cutoffs):
    """Return the index of the bin that *value* falls into.

    The cutoffs are scanned from left to right and the position of the first
    cutoff strictly greater than *value* is returned. When no cutoff exceeds
    *value*, the result is ``len(cutoffs)`` (the open-ended last bin).
    """
    lastBin = len(cutoffs)
    return next(
        (position for position in range(lastBin) if value < cutoffs[position]),
        lastBin,
    )


def getAffineTransfCoefs(origBounds, newBounds, sameRatio=False):
    """Compute per-axis affine coefficients mapping one bounding box onto another.

    With ``sameRatio`` false, the returned coefficients satisfy:
        x1 = a_x*X1 + b_x      x2 = a_x*X2 + b_x
        y1 = a_y*Y1 + b_y      y2 = a_y*Y2 + b_y

    origBounds: [X1, X2, Y1, Y2] in the source coordinate system.
    newBounds:  [x1, x2, y1, y2] in the target coordinate system.
    sameRatio:  when true, both axes use the smaller of the two scale
                magnitudes (each axis keeps its own sign). Only the
                (X1, Y1) -> (x1, y1) correspondence is then guaranteed.

    Returns [a_x, b_x, a_y, b_y].
    """
    X1, X2, Y1, Y2 = origBounds
    x1, x2, y1, y2 = newBounds

    scaleX = 1. * (x1 - x2) / (X1 - X2)
    scaleY = 1. * (y1 - y2) / (Y1 - Y2)

    if sameRatio:
        magnitudeX = abs(scaleX)
        magnitudeY = abs(scaleY)
        common = min(magnitudeX, magnitudeY)
        # Dividing by the magnitude keeps only the sign of each scale.
        scaleX = common * scaleX / magnitudeX
        scaleY = common * scaleY / magnitudeY

    offsetX = x1 - scaleX * X1
    offsetY = y1 - scaleY * Y1
    return [scaleX, offsetX, scaleY, offsetY]

def affineTransformation4Points(points, coefs):
    """Apply a per-axis affine transformation to a sequence of points.

    points: iterable of indexable points; only the first two coordinates
            of each point are read.
    coefs:  [a_x, b_x, a_y, b_y]; x becomes a_x*x + b_x, y becomes a_y*y + b_y.

    Returns a new list of [x, y] lists.
    """
    scaleX, shiftX, scaleY, shiftY = coefs
    transformed = []
    for point in points:
        transformed.append([scaleX * point[0] + shiftX, scaleY * point[1] + shiftY])
    return transformed

def affineTransformation4Polygons(elements, coefs):
    """Apply an affine transformation to all coordinates of the polygons, in place.

    elements: list of dicts holding a GeoJSON-like 'geometry' entry, i.e.
              {'type': 'Polygon' | 'MultiPolygon', 'coordinates': ...}.
              The coordinate containers must be mutable lists.
    coefs:    [a_x, b_x, a_y, b_y], forwarded to affineTransformation4Points.

    Every ring of every polygon is transformed (exterior ring and holes), so
    that a geometry never ends up with rings in two different coordinate
    systems.

    Returns the same ``elements`` list.
    Raises ValueError when a geometry is neither a Polygon nor a MultiPolygon.
    """
    def _transformRings(rings):
        # Replace each ring in place so existing references to the list stay valid.
        for ringIdx, ring in enumerate(rings):
            rings[ringIdx] = affineTransformation4Points(ring, coefs)

    for element in elements:
        geometry = element['geometry']
        geometryType = geometry['type']
        if geometryType == 'Polygon':
            _transformRings(geometry['coordinates'])
        elif geometryType == 'MultiPolygon':
            for polygonRings in geometry['coordinates']:
                _transformRings(polygonRings)
        else:
            raise ValueError("Unsupported geometry type: %r" % (geometryType,))
    return elements

def getBoundsFromShapes(shapes, epsilon=0):
    """Return the bounding box enclosing all *shapes*, optionally padded.

    Each shape must expose ``bounds`` as (minX1, minX2, maxX1, maxX2).
    The result is [minX1, minX2, maxX1, maxX2]: the smallest of the minima
    and the largest of the maxima along each axis. The minima are then
    lowered by *epsilon* and the maxima raised by *epsilon*.

    With no shapes the result is [inf, inf, -inf, -inf] (before padding).
    """
    infinity = float('inf')
    lows = [infinity, infinity]
    highs = [-infinity, -infinity]
    for shape in shapes:
        box = shape.bounds
        lows = [min(lows[0], box[0]), min(lows[1], box[1])]
        highs = [max(highs[0], box[2]), max(highs[1], box[3])]
    return [lows[0] - epsilon, lows[1] - epsilon, highs[0] + epsilon, highs[1] + epsilon]

def deep_tuple2list(x):
    """Recursively convert tuples (and lists) into lists, at any nesting depth.

    A value whose exact type is neither ``tuple`` nor ``list`` is returned
    unchanged.
    """
    if type(x) in (tuple, list):
        return [deep_tuple2list(item) for item in x]
    return x

def geoConvert_shapely2dict(shapes):
    """Convert shape records into plain dictionaries.

    shapes: iterable of records exposing ``geometry`` and ``attributes``
            (e.g. the records read with shpreader.Reader).

    Returns a list of {'geometry': {...}, 'properties': ...} dicts, where the
    geometry is the result of ``mapping(record.geometry)`` with its
    'coordinates' entry converted from nested tuples to nested lists.
    """
    converted = []
    for record in shapes:
        geometry = mapping(record.geometry)
        geometry['coordinates'] = deep_tuple2list(geometry['coordinates'])
        converted.append({'geometry': geometry, 'properties': record.attributes})
    return converted

def geoConvert_dict2shapely(shapes):
    """Rebuild shape objects from GeoJSON-like dictionaries.

    shapes: iterable of dicts with a 'geometry' entry
            ({'type': 'Polygon' | 'MultiPolygon', 'coordinates': ...})
            and a 'properties' entry.

    Returns a list of simple objects exposing ``geometry`` (a shapely Polygon
    or MultiPolygon), ``attributes`` (the 'properties' entry) and ``bounds``
    (the geometry's bounds).

    Only the first ring of each polygon is used; any further rings (holes)
    are dropped.

    Raises ValueError when a geometry is neither a Polygon nor a MultiPolygon.
    """
    class Object(object):
        pass

    shapes2 = []
    for ss in shapes:
        geometry = ss['geometry']
        geomType = geometry['type']
        if geomType == 'Polygon':
            geom = Polygon(geometry['coordinates'][0])
        elif geomType == 'MultiPolygon':
            geom = MultiPolygon([Polygon(pp[0]) for pp in geometry['coordinates']])
        else:
            # Fail loudly: without this branch an unsupported type would
            # silently reuse the geometry built for the previous record.
            raise ValueError("Unsupported geometry type: %r" % (geomType,))
        s2 = Object()
        s2.geometry = geom
        s2.attributes = ss['properties']
        s2.bounds = geom.bounds
        shapes2.append(s2)
    return shapes2

##### Functions to plot maps in SVG format
def makeSVGstring4path(points):
    """Build the ``d`` attribute of an SVG <path> going through *points* in order.

    The string starts with 'M', consecutive "x,y" pairs are separated by 'L',
    and the path is closed with 'Z'. An empty sequence yields a bare 'M'.
    """
    if len(points) == 0:
        return 'M'
    pairs = [str(point[0]) + ',' + str(point[1]) for point in points]
    return 'M' + 'L'.join(pairs) + 'Z'

def addSVGpath4Polygon(parentNode, polygon, properties):
    """Append a <path> element tracing a polygon's exterior ring to *parentNode*.

    parentNode: ElementTree element that receives the new <path> child.
    polygon:    object whose ``exterior.coords.xy`` gives the x and y
                coordinate sequences of the ring.
    properties: mapping of attribute name -> value, set on the <path>.
    """
    xs, ys = polygon.exterior.coords.xy
    pathData = makeSVGstring4path(list(zip(xs, ys)))
    pathNode = ET.SubElement(parentNode, 'path')
    pathNode.set('d', pathData)
    for attrName in properties:
        pathNode.set(attrName, properties[attrName])

def _addSVGshapes(group, svgShapes):
    """Append one <path> per polygon of every shape entry to *group*."""
    for svgShape in svgShapes:
        geom = svgShape['shape'].geometry
        properties = svgShape['properties']
        if geom.geom_type == 'Polygon':
            addSVGpath4Polygon(group, geom, properties)
        elif geom.geom_type == 'MultiPolygon':
            for polygon in geom.geoms:
                addSVGpath4Polygon(group, polygon, properties)


def _addSVGpoints(group, svgPoints):
    """Append a <circle> (and an optional <text> label) per point to *group*."""
    for point in svgPoints:
        coords = point['coords']
        circle = ET.SubElement(group, 'circle', {'cx': str(coords[0]), 'cy': str(coords[1])})
        for key, value in point['pointProps'].items():
            circle.set(key, value)
        if 'label' in point:
            # The label is drawn 10 units above the point.
            text = ET.SubElement(group, 'text', {'x': str(coords[0]), 'y': str(-10 + coords[1])})
            text.text = point['label']
            # 'labelProps' is optional: a label without extra attributes is valid.
            for key, value in point.get('labelProps', {}).items():
                text.set(key, value)


def _addSVGtitle(svg, svgTitle):
    """Append the title <text> element to the root *svg* element."""
    text = ET.SubElement(svg, 'text', {'x': str(svgTitle['xPosition']), 'y': str(svgTitle['yPosition'])})
    text.text = svgTitle['label']
    for key, value in svgTitle['properties'].items():
        text.set(key, value)


def _addSVGlegend(svg, svgLegend, svgHeight):
    """Append the legend items to the root *svg* element, stacked vertically."""
    location = svgLegend['location']
    heightPerItem = svgLegend['heightPerItem']
    widthPerSymbol = svgLegend['widthPerSymbol']
    items = svgLegend['items']
    if location != 'southwest':
        # Only one anchor is implemented; any other value used to fail later
        # with a NameError on the undefined cursor position.
        raise ValueError("Unsupported legend location: %r (only 'southwest' is implemented)" % (location,))
    x_cur = svgLegend['x_offset']
    y_cur = svgHeight - len(items) * heightPerItem

    for item in items:
        # Rectangles
        if item['type'] == 'rect':
            label = ET.SubElement(svg, 'text', {'y': str(y_cur + 10), 'x': str(x_cur + widthPerSymbol + 10)})
            label.text = item['label']
            swatch = ET.SubElement(svg, 'rect', {'y': str(y_cur), 'x': str(x_cur), 'width': str(widthPerSymbol), 'height': str(heightPerItem / 1.5)})
            for key, value in item['props'].items():
                swatch.set(key, value)

        # Dots of multiple sizes
        elif item['type'] == 'dotMultipleSizes':
            subitems = item['subitems']
            maxRadius = max([sub['radius'] for sub in subitems])
            # Dots first, then labels, then lines: document order is drawing order.
            for sub in subitems:
                radius = sub['radius']
                dot = ET.SubElement(svg, 'circle', {'cx': str(x_cur + maxRadius), 'cy': str(y_cur + maxRadius - radius), 'r': str(radius)})
                for key, value in sub['props'].items():
                    dot.set(key, value)
            for sub in subitems:
                # Read the radius BEFORE positioning the label, so that each
                # label is placed using its own sub-item's radius.
                radius = sub['radius']
                # NOTE(review): the sign of the radius term here is the opposite
                # of the one used for the guide lines below - confirm the
                # intended label placement.
                label = ET.SubElement(svg, 'text', {'y': str(y_cur - maxRadius + 2 * radius - 5), 'x': str(x_cur + widthPerSymbol + 5), 'font-size': '9px'})
                label.text = sub['label']
            for sub in subitems:
                radius = sub['radius']
                lineY = str(y_cur + maxRadius - 2 * radius)
                ET.SubElement(svg, 'line', {'x1': str(x_cur + maxRadius), 'x2': str(x_cur + widthPerSymbol), 'y1': lineY, 'y2': lineY, 'width': '2', 'stroke': 'black', 'stroke-width': "0.2"})

        # Items of any other type draw nothing but still take up one row.
        y_cur = y_cur + heightPerItem


def plotMapSVG(svgObject, outputFile):
    """Render *svgObject* as an SVG document and write it to *outputFile*.

    svgObject is a dict with:
      'height', 'width'  (required) dimensions of the image;
      'shapes'  (optional) list of {'shape': obj with .geometry, 'properties': {...}};
      'points'  (optional) list of {'coords': (x, y), 'pointProps': {...}}
                with optional 'label' and 'labelProps';
      'title'   (optional) {'xPosition', 'yPosition', 'label', 'properties'};
      'legend'  (optional) {'location', 'heightPerItem', 'widthPerSymbol',
                'x_offset', 'items'}; items are of type 'rect' or
                'dotMultipleSizes'.

    Shapes and points are grouped in a <g> element; the title and the legend
    are direct children of the root <svg> element.

    Raises ValueError when the legend location is not 'southwest'.
    """
    svgHeight = svgObject['height']
    svgWidth = svgObject['width']

    # Create a svg element
    svg = ET.Element('svg', {'height': str(svgHeight), 'width': str(svgWidth), 'xmlns': 'http://www.w3.org/2000/svg'})
    g1 = ET.SubElement(svg, 'g')

    if 'shapes' in svgObject:
        _addSVGshapes(g1, svgObject['shapes'])
    if 'points' in svgObject:
        _addSVGpoints(g1, svgObject['points'])
    if 'title' in svgObject:
        _addSVGtitle(svg, svgObject['title'])
    if 'legend' in svgObject:
        _addSVGlegend(svg, svgObject['legend'], svgHeight)

    # Output to file (ET.tostring returns bytes, hence binary mode)
    with open(outputFile, "wb") as fout:
        fout.write(ET.tostring(svg))
########################################################################################################################


# Command-line arguments: six mandatory positional values, followed by an
# optional number of decimals for the legend labels (defaults to 1).
outputFile, regionShpFile, bigShpFile, dataFile, titleLabel, cutoffJsonFile = sys.argv[1:7]
NumDecimals4legend = int(sys.argv[7]) if len(sys.argv) > 7 else 1

# Dimensions of the image
svgWidth = 700
svgHeight = svgWidth - 150
# Originally labelled "left, top, right, bottom".
# NOTE(review): the affine-transformation call further down uses index 1 as
# the distance from the bottom edge and index 3 as the distance from the top
# edge - confirm which order is intended.
paddings = [160, 10, 40, 40]

# Read shape files (the "big" background shapes are optional: pass '' to skip them)
if bigShpFile != '':
    bigShapes = list(shpreader.Reader(bigShpFile).records())
else:
    bigShapes = []
regionShapes = list(shpreader.Reader(regionShpFile).records())
NumRegions = len(regionShapes)

# Read file that associates a value to each region
[regionIds, regionValues] = readDataFromFile(dataFile, ',', True, [1,2])
regionValues = [float(x) for x in regionValues]

# The header names the shapefile attribute holding the region id (first
# column) and the value being mapped (second column). The context manager
# guarantees the file is closed even if reading fails.
with open(dataFile, 'r') as fin:
    header = fin.readline()
headerFields = header.split(',')
shapeIdField = headerFields[0]
valueIdField = headerFields[1].rstrip()

### Map the original projection (e.g. Lambert 93) onto pixel coordinates
# getBoundsFromShapes returns [minX1, minX2, maxX1, maxX2], whereas
# getAffineTransfCoefs expects [X1, X2, Y1, Y2]: reorder accordingly.
_minX1, _minX2, _maxX1, _maxX2 = getBoundsFromShapes(regionShapes)
coordsBounds = [_minX1, _maxX1, _minX2, _maxX2]

# Target box in pixels. The vertical axis is flipped (the smallest source Y
# goes to the largest pixel y) because SVG y grows downwards.
_pixelBounds = [paddings[0], svgWidth - paddings[2], svgHeight - paddings[1], paddings[3]]
affineTransf_coefs = getAffineTransfCoefs(coordsBounds, _pixelBounds, True)

# Round trip: records -> dicts -> transformed dicts -> shape objects.
bigShapes = geoConvert_shapely2dict(bigShapes)
bigShapes = affineTransformation4Polygons(bigShapes, affineTransf_coefs)
bigShapes = geoConvert_dict2shapely(bigShapes)

regionShapes = geoConvert_shapely2dict(regionShapes)
regionShapes = affineTransformation4Polygons(regionShapes, affineTransf_coefs)
regionShapes = geoConvert_dict2shapely(regionShapes)


### Put together the regions (coords and value)
# Index the values by region id once, instead of scanning the id list for
# every shape. setdefault keeps the FIRST occurrence of a duplicated id,
# which is what a left-to-right search of regionIds would return.
# (assumes the ids are hashable, e.g. strings read from the data file)
valueByRegionId = {}
for _regionId, _regionValue in zip(regionIds, regionValues):
    valueByRegionId.setdefault(_regionId, _regionValue)

regionsData = []
for myshape in regionShapes:
    myregionId = myshape.attributes[shapeIdField]
    if myregionId in valueByRegionId:
        myRegionValue = valueByRegionId[myregionId]
    else:
        # Best effort: a region without data is reported and drawn with value 0.
        myRegionValue = 0
        print('Could not find region id: ' + str(myregionId) + ' in the region data file')
    regionsData.append({'shape':myshape, 'value': myRegionValue})


#################### Define colors and cutoffs for the regions ####################
# Define color palette
def rgb_to_hex(rgb):
    """Format a 3-tuple of integers as six lowercase hex digits (no leading '#')."""
    hexTemplate = '%02x' * 3
    return hexTemplate % rgb
# Alternative (orange) palette, kept for reference:
#colors = [(255,245,235),(254,230,206),(253,208,162),(253,174,107),(253,141,60),(241,105,19),(217,72,1),(166,54,3),(127,39,4)]
# Nine gray levels from white (255) down to near-black (15), in steps of 30.
grayLevels = range(255, 0, -30)
colors = ["#" + rgb_to_hex((level, level, level)) for level in grayLevels]

# Load the cutoffs: the JSON file maps a value-field name to its list of
# cutoffs, and the entry matching the data file's value column is used.
with open(cutoffJsonFile) as json_file:
    cutoffs = json.load(json_file)[valueIdField]

def getColor(regionValue):
    """Return the palette color of the bin that *regionValue* falls into."""
    binIndex = getBinFromValueAndCutoffs(regionValue, cutoffs)
    return colors[binIndex]

#################### Make svgObject and output file ####################
#### Add shapes
# Background shapes first: fully transparent fill with a thick black outline.
svgShapes = [
    {'shape': bigShape,
     'properties': {'fill': 'white', 'fill-opacity': '0.0', 'stroke': 'black', 'stroke-width': "2"}}
    for bigShape in bigShapes
]

# Then the regions, filled according to their value, with a thin outline.
svgShapes.extend(
    {'shape': region['shape'],
     'properties': {'fill': getColor(region['value']), 'stroke': 'black', 'stroke-width': "0.2"}}
    for region in regionsData
)

#### Add legend
svgLegend = {'location': 'southwest', 'heightPerItem': 20, 'x_offset': 10 , 'widthPerSymbol' : 30}

if NumDecimals4legend < 0:
    raise Exception('Unknown number of digits')


def _formatCutoff(cutoffValue):
    """Format a cutoff with the number of decimals requested for the legend."""
    # '%.*f' reads the precision from the tuple, so any number of decimals works.
    return '%.*f' % (NumDecimals4legend, cutoffValue)


# One rectangle per color: the first bin is open below, the last one is open
# above, and every other bin lies between two consecutive cutoffs.
svgLegendItems = []
lastColorIdx = len(colors) - 1
for cidx, mycolor in enumerate(colors):
    if cidx == 0:
        mylabel = '< ' + _formatCutoff(cutoffs[cidx])
    elif cidx < lastColorIdx:
        mylabel = _formatCutoff(cutoffs[cidx-1]) + ' - ' + _formatCutoff(cutoffs[cidx])
    else:
        mylabel = _formatCutoff(cutoffs[cidx-1]) + ' +'
    svgLegendItems.append({'label': mylabel, 'type': 'rect', 'props': {'fill': mycolor}})

svgLegend['items'] = svgLegendItems

#### Add title
svgTitle = {
    'label': titleLabel,
    'xPosition': "50%",
    'yPosition': '30',
    'properties': {'font-size': '30', 'font-family':"Arial"},
}


#### Make the plot
# Assemble everything plotMapSVG needs (no points are drawn on this map).
svgObject = dict(
    shapes=svgShapes,
    points=[],
    height=svgHeight,
    width=svgWidth,
    title=svgTitle,
    legend=svgLegend,
)
plotMapSVG(svgObject, outputFile)
